import torch
import torch.utils.data
import logging
import random
import numpy as np
import time
from datetime import timedelta
import copy

# CIFAR-10 class names, kept for reference. Previously this list was assigned
# to CIFAR10_classes and then immediately overwritten (dead store).
CIFAR10_class_names = ["airplane", "automobile", "bird",  "cat",  "deer",
					   "dog",      "frog",       "horse", "ship", "truck"]
# Exported class labels: the label ids "0".."9" as strings.
CIFAR10_classes = [str(label) for label in range(10)]

class LogFormatter(logging.Formatter):
	"""Format records as ``LEVEL - wall-clock time - elapsed time - message``.

	Elapsed time is measured from ``start_time``, which is set at construction
	and may be reassigned by the caller to restart the counter. Continuation
	lines of multi-line messages (including tracebacks) are indented so they
	line up with the first line's message text.
	"""

	def __init__(self):
		super().__init__()
		# Reference point for the elapsed-time column.
		self.start_time = time.time()

	def format(self, record):
		"""Return the formatted text for ``record`` ("" when there is nothing to show)."""
		elapsed_seconds = round(record.created - self.start_time)

		prefix = "%s - %s - %s" % (
			record.levelname,
			# Use the record's creation time, not the time of formatting.
			time.strftime("%x %X", time.localtime(record.created)),
			timedelta(seconds=elapsed_seconds),
		)
		message = record.getMessage()

		# Append tracebacks from logger.exception(...) / exc_info=True and
		# stack_info=True; these were previously dropped silently.
		if record.exc_info and not record.exc_text:
			record.exc_text = self.formatException(record.exc_info)
		if record.exc_text:
			message = "%s\n%s" % (message, record.exc_text) if message else record.exc_text
		if record.stack_info:
			stack_text = self.formatStack(record.stack_info)
			message = "%s\n%s" % (message, stack_text) if message else stack_text

		# Indent continuation lines past the prefix and its " - " separator.
		message = message.replace("\n", "\n" + " " * (len(prefix) + 3))
		return "%s - %s" % (prefix, message) if message else ""

def create_logger(filepath, rank):
	"""Configure the root logger with a console handler and an optional file handler.

	Handlers already attached to the root logger are closed and removed, so
	calling this more than once does not leak open log files.

	Args:
		filepath: path of the log file to append to, or ``None`` for
			console-only logging. When ``rank > 0`` the file name becomes
			``"<filepath>-<rank>"``.
		rank: integer rank used only to derive the log file name.

	Returns:
		The root ``logging.Logger`` at level DEBUG. The file handler logs
		DEBUG and above; the console handler logs INFO and above. The logger
		gets an extra ``reset_time()`` attribute that restarts the elapsed-time
		column of the shared formatter.
	"""
	# One formatter shared by both handlers so reset_time() affects both.
	log_formatter = LogFormatter()

	handlers = []

	# File handler (DEBUG and above), only when a path is given.
	if filepath is not None:
		if rank > 0:
			filepath = "%s-%i" % (filepath, rank)
		file_handler = logging.FileHandler(filepath, "a")
		file_handler.setLevel(logging.DEBUG)
		file_handler.setFormatter(log_formatter)
		handlers.append(file_handler)

	# Console handler (INFO and above).
	console_handler = logging.StreamHandler()
	console_handler.setLevel(logging.INFO)
	console_handler.setFormatter(log_formatter)
	handlers.append(console_handler)

	logger = logging.getLogger()
	# Close previously installed handlers before dropping them; simply
	# reassigning logger.handlers left old FileHandlers (and their open files)
	# alive on repeated calls.
	for old_handler in list(logger.handlers):
		logger.removeHandler(old_handler)
		old_handler.close()
	logger.setLevel(logging.DEBUG)
	logger.propagate = False
	for handler in handlers:
		logger.addHandler(handler)

	# Restart the elapsed-time counter shown in each formatted line.
	def reset_time():
		log_formatter.start_time = time.time()

	logger.reset_time = reset_time

	return logger

def setup_seed(seed):
	"""Seed all RNGs with ``seed`` and request deterministic cuDNN behavior.

	Seeds PyTorch (CPU and all CUDA devices), Python's ``random`` module and
	NumPy's global RNG with the same value.
	"""
	torch.manual_seed(seed)
	torch.cuda.manual_seed_all(seed)
	random.seed(seed)
	np.random.seed(seed)
	torch.backends.cudnn.deterministic = True
	# deterministic=True alone is not sufficient. If benchmark mode was enabled
	# elsewhere, cuDNN autotuning can pick different algorithms between runs,
	# so turn it off explicitly.
	torch.backends.cudnn.benchmark = False

def split_dataset(dataset, total_number):
	"""Return a copy of ``dataset`` keeping at most ``total_number // C`` samples per class.

	``C`` is the number of distinct labels in ``dataset.targets``. For each
	label, in ascending order, the first matching samples in dataset order are
	kept. The result is therefore deterministic and grouped by class. Classes
	with fewer samples than the quota keep all of their samples. The input
	dataset is not modified: the result is a deep copy whose ``data`` and
	``targets`` are replaced.

	Args:
		dataset: object exposing ``data`` (indexed along its first axis, one
			entry per sample) and ``targets`` (one label per sample). This
			presumably matches torchvision-style datasets; verify against callers.
		total_number: requested total number of samples after the split.

	Returns:
		The new dataset, with ``data`` as a NumPy array and ``targets`` as a
		Python list.

	Raises:
		ValueError: if ``dataset.targets`` is empty.
	"""
	targets = np.asarray(dataset.targets)
	labels = targets.tolist()
	classes = sorted(set(labels))
	if not classes:
		raise ValueError("split_dataset: dataset has no targets")
	logging.getLogger(__name__).debug("split_dataset: %d classes", len(classes))
	per_class = total_number // len(classes)

	# Single pass: collect the first `per_class` indices of every label.
	# The old version rescanned the data once per class and assumed the
	# labels were exactly 0..C-1, so other label sets produced an empty result.
	picked = {label: [] for label in classes}
	for idx, label in enumerate(labels):
		bucket = picked[label]
		if len(bucket) < per_class:
			bucket.append(idx)

	# Join the per-class index lists so the output stays grouped by class.
	order = np.asarray(
		[idx for label in classes for idx in picked[label]], dtype=np.intp
	)

	new_dataset = copy.deepcopy(dataset)
	# Index once with the combined order. The old np.vstack of per-class label
	# arrays raised whenever classes ended up with unequal sample counts.
	new_dataset.data = np.asarray(dataset.data)[order]
	new_dataset.targets = targets[order].tolist()
	return new_dataset


